"""
nuScenes雷达点云分屏可视化脚本
功能：在五个独立的3D子图中分别显示同一场景下不同雷达的点云
环境要求：Python 3.6+，需安装以下包：
  pip install nuscenes-devkit matplotlib numpy pyquaternion
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from nuscenes import NuScenes
from pyquaternion import Quaternion
from nuscenes.utils.data_classes import LidarPointCloud, RadarPointCloud

# 配置参数
DATASET_ROOT = "./mini"  # 必须修改为实际数据集路径
SCENE_INDEX = 0                                # 场景索引（0表示第一个场景）
MAX_DISTANCE = 100.0                           # 最大可视距离（米）
POINT_SIZE = 1                                # 点云大小
SUBPLOT_SIZE = (16, 10)                       # 画布尺寸（英寸）
VIEW_ELEV = 30                                # 俯仰角
VIEW_AZIM = -90                               # 方位角

# 雷达传感器列表（nuScenes mini包含的五个雷达）
RADAR_NAMES = [
    "RADAR_FRONT",
    "RADAR_FRONT_LEFT",
    "RADAR_FRONT_RIGHT",
    "RADAR_BACK_LEFT",
    "RADAR_BACK_RIGHT",
]

def load_radar_data(nusc, sample_token, radar_name):
    """加载并转换指定雷达的点云数据到全局坐标系"""
    # 获取雷达数据记录
    radar_data = nusc.get("sample_data", sample_token["data"][radar_name])
    
    # 加载点云二进制文件
    file_path = os.path.join(nusc.dataroot, radar_data["filename"])
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"雷达文件不存在：{file_path}")
    pc = RadarPointCloud.from_file(file_path)
    pc = pc.points.T 

    # # 坐标系转换：雷达坐标系 -> 全局坐标系
    # calib = nusc.get("calibrated_sensor", radar_data["calibrated_sensor_token"])
    # ego_pose = nusc.get("ego_pose", radar_data["ego_pose_token"])
    
    # points = np.vstack((pc[:3, :], np.ones(pc.shape[1])))
    # points_vehicle = calib["rotation"].rotation_matrix @ points[:3, :] + np.array(calib["translation"]).reshape(-1, 1)
    # points_global = ego_pose["rotation"].rotation_matrix @ points_vehicle + np.array(ego_pose["translation"]).reshape(-1, 1)
    
    return pc[:, :3]  # (num_points, 3)

def create_subplots():
    """创建2x3布局的子图（第五个位置留空）"""
    fig = plt.figure(figsize=SUBPLOT_SIZE)
    axes = [
        fig.add_subplot(231, projection='3d'),
        fig.add_subplot(232, projection='3d'),
        fig.add_subplot(233, projection='3d'),
        fig.add_subplot(234, projection='3d'),
        fig.add_subplot(235, projection='3d')
    ]
    return fig, axes

def configure_axes(ax, title):
    """统一配置子图参数"""
    ax.set_xlim(-MAX_DISTANCE, MAX_DISTANCE)
    ax.set_ylim(-MAX_DISTANCE, MAX_DISTANCE)
    ax.set_zlim(-5, 5)  # 限制Z轴显示高度
    ax.view_init(elev=VIEW_ELEV, azim=VIEW_AZIM)
    ax.set_title(title, fontsize=10)
    ax.set_xlabel('X (m)', fontsize=8)
    ax.set_ylabel('Y (m)', fontsize=8)
    ax.set_zlabel('Z (m)', fontsize=8)
    ax.xaxis.set_tick_params(labelsize=6)
    ax.yaxis.set_tick_params(labelsize=6)
    ax.zaxis.set_tick_params(labelsize=6)

def plot_single_radar(ax, points, radar_name):
    """在指定子图中绘制单个雷达点云"""
    # 应用距离过滤
    valid_points = points[np.linalg.norm(points, axis=1) < MAX_DISTANCE]
    
    # 提取坐标分量
    x = valid_points[:, 0]
    y = valid_points[:, 1]
    z = valid_points[:, 2]
    
    # 绘制点云
    ax.scatter(x, y, z, 
               c=z,  # 使用Z轴值作为颜色映射
               cmap='viridis',
               s=POINT_SIZE, 
               alpha=0.8,
               depthshade=False)
    
    # # 绘制雷达位置标记
    # ax.scatter(0, 0, 0, 
    #            c='red', 
    #            marker='^', 
    #            s=50, 
    #            label='Sensor Position')
    ax.legend(loc='upper right', fontsize=6)

def main():
    # 初始化数据集
    nusc = NuScenes(version="v1.0-mini", dataroot=DATASET_ROOT, verbose=False)
    
    # 获取场景信息
    scene = nusc.scene[SCENE_INDEX]
    sample = nusc.get("sample", scene["first_sample_token"])
    
    # 创建画布和子图
    fig, axes = create_subplots()
    plt.suptitle(f"Radar Point Clouds - Scene: {scene['name']}", y=0.95)
    
    # 逐个处理雷达数据
    for idx, radar_name in enumerate(RADAR_NAMES):
        try:
            points = load_radar_data(nusc, sample, radar_name)
            configure_axes(axes[idx], radar_name)
            plot_single_radar(axes[idx], points, radar_name)
        except Exception as e:
            print(f"Error processing {radar_name}: {str(e)}")
            continue
    
    # 调整布局并显示
    plt.tight_layout()
    plt.subplots_adjust(top=0.85)  # 给主标题留出空间
    plt.show()

if __name__ == "__main__":
    main()